{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "'想要有直升机\\n想要和你飞到宇宙去\\n想要和你融化在一起\\n融化在宇宙里\\n我每天每天每'"
     },
     "metadata": {},
     "execution_count": 2
    }
   ],
   "source": [
    "import torch\n",
    "import random\n",
    "import zipfile\n",
    "\n",
    "with zipfile.ZipFile(\"../data/jaychou_lyrics.txt.zip\")as zin:\n",
    "    with zin.open(\"jaychou_lyrics.txt\") as f:\n",
    "        corpus_chars = f.read().decode(\"utf-8\")\n",
    "\n",
    "corpus_chars[:40]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "## 周杰伦歌词数据集\n",
    "# 这个数据集有６万多个字符，为了打印方便，将换行符换成空格\n",
    "corpus_chars = corpus_chars.replace(\"\\n\",\" \").replace(\"\\r\", \" \")\n",
    "# corpus_chars = corpus_chars[0:10000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立字符索引\n",
    "将每个字符映射为一个从０开始的整数，称为索引，方便后续的数据处理。　为了得到索引，将数据集中的所有不同字符取出来，然后将其逐一映射构建字典。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1027"
     },
     "metadata": {},
     "execution_count": 7
    }
   ],
   "source": [
    "idx_to_char = list(set(corpus_chars))\n",
    "char_to_idx = dict([(char,i) for i,char in enumerate(idx_to_char)])\n",
    "vocab_size = len(char_to_idx)\n",
    "vocab_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "chars: 想 要 有 直 升 机   想 要 和 你 飞 到 宇 宙 去   想 要 和\nindices: [356, 466, 720, 696, 248, 184, 219, 356, 466, 227, 505, 199, 973, 88, 1002, 509, 219, 356, 466, 227]\n"
    }
   ],
   "source": [
    "# 将数据集中的字符转换为索引\n",
    "corpus_indices = [char_to_idx[char] for char in corpus_chars]\n",
    "sample = corpus_indices[:20]\n",
    "print(\"chars:\",\" \".join([idx_to_char[idx] for idx in sample]))\n",
    "print(\"indices:\",sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_iter_random(corpus_indices,batch_size,num_steps,device=None):\n",
    "    # -1是因为输出的索引ｘ市相应的输入的索引ｙ+1\n",
    "    num_examples = (len(corpus_indices) - 1) // num_steps\n",
    "    epoch_size = num_examples // batch_size\n",
    "\n",
    "    examples_indices = list(range(num_examples))\n",
    "    random.shuffle(examples_indices)\n",
    "\n",
    "    def _data(pos):\n",
    "        return corpus_indices[pos:pos + num_steps]\n",
    "    \n",
    "    if device is None:\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    \n",
    "    for i in range(epoch_size):\n",
    "        i = i * batch_size\n",
    "        batch_indices = examples_indices[i: i + batch_size]\n",
    "        x = [_data(j * num_steps) for j in batch_indices]\n",
    "        y = [_data(j * num_steps + 1) for j in batch_indices]\n",
    "        yield torch.tensor(x,dtype=torch.float32,device=device),torch.tensor(y,dtype=torch.float32,device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "x: tensor([[18., 19., 20., 21., 22., 23.],\n        [ 0.,  1.,  2.,  3.,  4.,  5.]], device='cuda:0')\ny: tensor([[19., 20., 21., 22., 23., 24.],\n        [ 1.,  2.,  3.,  4.,  5.,  6.]], device='cuda:0')\nx: tensor([[ 6.,  7.,  8.,  9., 10., 11.],\n        [12., 13., 14., 15., 16., 17.]], device='cuda:0')\ny: tensor([[ 7.,  8.,  9., 10., 11., 12.],\n        [13., 14., 15., 16., 17., 18.]], device='cuda:0')\n"
    }
   ],
   "source": [
    "my_seq = list(range(30))\n",
    "# batch_size = 2,每次取２个样本\n",
    "# steps = 6，每个样本是由６个样本组成的片段\n",
    "for x,y in data_iter_random(my_seq,batch_size=2,num_steps=6): \n",
    "    print(\"x:\",x)\n",
    "    print(\"y:\",y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_iter_consecutive(corpus_indices,batch_size,num_stpes,device=None):\n",
    "    if device is None:\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    \n",
    "    corpus_indices = torch.tensor(corpus_indices,dtype=torch.float32,device=device)\n",
    "    data_len = len(corpus_indices)\n",
    "    batch_len = data_len // batch_size\n",
    "    \n",
    "    indices = corpus_indices[0:batch_size * batch_len].view(batch_size,batch_len)\n",
    "    \n",
    "    epoch_size = (batch_len - 1) // num_stpes\n",
    "\n",
    "    for i in range(epoch_size):\n",
    "        i = i * num_stpes\n",
    "        x = indices[:,i : i + num_stpes]\n",
    "        y = indices[:,i + 1 : i + num_stpes + 1]\n",
    "        yield x,y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "x: tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],\n        [15., 16., 17., 18., 19., 20.]], device='cuda:0')\ny: tensor([[ 1.,  2.,  3.,  4.,  5.,  6.],\n        [16., 17., 18., 19., 20., 21.]], device='cuda:0')\nx: tensor([[ 6.,  7.,  8.,  9., 10., 11.],\n        [21., 22., 23., 24., 25., 26.]], device='cuda:0')\ny: tensor([[ 7.,  8.,  9., 10., 11., 12.],\n        [22., 23., 24., 25., 26., 27.]], device='cuda:0')\n"
    }
   ],
   "source": [
    "for x,y in data_iter_consecutive(my_seq,batch_size=2,num_stpes=6):\n",
    "    print(\"x:\",x)\n",
    "    print(\"y:\",y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "[1687, 2052, 1784, 372, 1111, 843, 48, 1687, 2052, 1359, 29, 423, 533, 2283, 2090, 1349, 48, 1687, 2052, 1359, 29, 102, 1421, 2437, 1275, 1826, 48, 102, 1421, 2437, 2283, 2090, 411, 48, 78, 1346, 1128, 1346, 1128, 1346]\n2582\n"
    }
   ],
   "source": [
    "import d2lzh as d2l\n",
    "\n",
    "corpus_indices,char_to_idx,idx_to_char,vocab_size = d2l.load_data_jay_lyrics()\n",
    "\n",
    "print(corpus_indices[:40])\n",
    "print(vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "想\n"
    }
   ],
   "source": [
    "print(idx_to_char[corpus_indices[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}